{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Apply the fine-tuned DistilBERT housing classifier to the full set of documents\n",
    "\n",
    "Loads the fine-tuned DistilBERT model (`distilbert-base-uncased-finetuned-housing`), classifies every text in `data_dontshare/data_analysis.parquet` as *Housing* or *Non-housing*, and writes the data together with the predicted label (`housing_bert`) and its score to `data_dontshare/data_analysis_classified_housing.parquet`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/smueller/miniconda3/envs/python-3.11.5/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "## standard library: platform to detect the OS, time to time the pipeline\n",
    "import platform\n",
    "import time\n",
    "\n",
    "## load general libraries\n",
    "import pandas as pd\n",
    "## load pyreadr library to open .rds file\n",
    "import pyreadr\n",
    "## load torch to query CUDA / MPS availability when choosing the device\n",
    "import torch\n",
    "## load relevant functions from transformers library\n",
    "from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline, pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# choose the device the classification pipeline runs on\n",
    "current_os = platform.system()\n",
    "if current_os == \"Darwin\" and torch.backends.mps.is_available():\n",
    "    # Apple GPU via the Metal Performance Shaders (mps) backend\n",
    "    device = \"mps\"\n",
    "elif torch.cuda.is_available():\n",
    "    # must be 'cuda:0', not just 'cuda' due to a bug in transformers library, see: https://github.com/deepset-ai/haystack/issues/3160\n",
    "    device = \"cuda:0\"\n",
    "else:\n",
    "    # no supported accelerator available (e.g. a Mac where MPS is unavailable)\n",
    "    device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mps\n"
     ]
    }
   ],
   "source": [
    "# report which device was selected for the pipeline\n",
    "print(f\"{device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify model: label mapping of the fine-tuned housing classifier\n",
    "id2label = {0: \"Non-housing\", 1: \"Housing\"}\n",
    "# inverse mapping and number of labels derived from id2label so the three cannot drift apart\n",
    "label2id = {label: idx for idx, label in id2label.items()}\n",
    "num_labels = len(id2label)\n",
    "\n",
    "# load fine-tuned model\n",
    "model_housing = AutoModelForSequenceClassification.from_pretrained(\n",
    "    \"distilbert-base-uncased-finetuned-housing\",\n",
    "    num_labels=num_labels,\n",
    "    id2label=id2label,\n",
    "    label2id=label2id,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load tokenizer of the base DistilBERT checkpoint (uncased)\n",
    "# NOTE(review): taken from the base checkpoint, not the fine-tuned model directory;\n",
    "# presumably fine-tuning used this same tokenizer - confirm\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "908895\n"
     ]
    }
   ],
   "source": [
    "# read the analysis data set from parquet with pandas\n",
    "path = \"data_dontshare/data_analysis.parquet\"\n",
    "data_analysis = pd.read_parquet(path)\n",
    "\n",
    "# number of rows, i.e. texts to classify\n",
    "n = len(data_analysis)\n",
    "print(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0     adams gerry_2018_tweets_1\n",
      "1     adams gerry_2018_tweets_2\n",
      "2     adams gerry_2018_tweets_3\n",
      "3     adams gerry_2018_tweets_4\n",
      "4     adams gerry_2018_tweets_5\n",
      "5     adams gerry_2018_tweets_6\n",
      "6     adams gerry_2018_tweets_7\n",
      "7     adams gerry_2018_tweets_8\n",
      "8     adams gerry_2018_tweets_9\n",
      "9    adams gerry_2018_tweets_10\n",
      "Name: document_id, dtype: object\n"
     ]
    }
   ],
   "source": [
    "# inspect the first 10 values of the document_id column\n",
    "print(data_analysis[\"document_id\"].iloc[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0    Uachtar reoite go deo! Yum Yum! https://t.co/t...\n",
      "1            The Three Amigos. https://t.co/DhCTTcOFmk\n",
      "2      Oiche mhaith. Xoxozzzxx https://t.co/cvHiuudrPT\n",
      "3                   Bígí Linn. https://t.co/gxTRshPwuQ\n",
      "4    No light @ the end of that tunnel. https://t.c...\n",
      "5    The excuses given by Taoiseach in the Dáil for...\n",
      "6    LEADERS DEBATE - Féile an Phobail https://t.co...\n",
      "7    Oiche mhaith apres Lá fada. XoxHallelujahzzz h...\n",
      "8    So it’s down to the Shoot Out! Mná na hÉireann...\n",
      "9              Up 4 An Fleadh! https://t.co/uiiPIfURkH\n",
      "Name: text, dtype: object\n"
     ]
    }
   ],
   "source": [
    "# inspect the first 10 values of the text column\n",
    "print(data_analysis[\"text\"].iloc[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# specify classification pipeline\n",
    "# padding=True: texts are not padded at tokenization, and each batch is padded only to its\n",
    "# longest member instead of always to 512 tokens. Padded positions are masked out, so\n",
    "# predictions should be unchanged up to floating-point noise, while short texts (tweets)\n",
    "# no longer pay for full 512-token forward passes.\n",
    "classify = pipeline(\"text-classification\",\n",
    "                    model = model_housing,\n",
    "                    tokenizer = tokenizer,\n",
    "                    device = device,\n",
    "                    batch_size = 64,\n",
    "                    # texts longer than max_length tokens are truncated\n",
    "                    max_length = 512,\n",
    "                    truncation = True,\n",
    "                    padding = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "908895"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# plain Python list of the texts, which is what gets passed to the pipeline\n",
    "text_data = list(data_analysis[\"text\"])\n",
    "\n",
    "# number of texts to classify\n",
    "len(text_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The pipeline call took 183.64 minutes.\n"
     ]
    }
   ],
   "source": [
    "# start the timer (perf_counter is the clock intended for measuring elapsed time)\n",
    "start_time = time.perf_counter()\n",
    "\n",
    "# classify every text; the pipeline yields one {'label': ..., 'score': ...} dict per text,\n",
    "# which are collected into a data frame with columns label and score\n",
    "dat_classified = pd.DataFrame(classify(text_data))\n",
    "\n",
    "end_time = time.perf_counter()\n",
    "elapsed_time = end_time - start_time\n",
    "\n",
    "# print the time it took to run the pipeline\n",
    "print(f\"The pipeline call took {elapsed_time/60:.2f} minutes.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Non-housing</td>\n",
       "      <td>0.984144</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Non-housing</td>\n",
       "      <td>0.984707</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Non-housing</td>\n",
       "      <td>0.984576</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Non-housing</td>\n",
       "      <td>0.984634</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Non-housing</td>\n",
       "      <td>0.983335</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         label     score\n",
       "0  Non-housing  0.984144\n",
       "1  Non-housing  0.984707\n",
       "2  Non-housing  0.984576\n",
       "3  Non-housing  0.984634\n",
       "4  Non-housing  0.983335"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# first five predictions (label and score per text)\n",
    "dat_classified.iloc[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['document_id', 'text', 'year', 'fulldate', 'harmonised', 'ownership',\n",
       "       'ownership_clean', 'ownership_3', 'ownership_2', 'ownership_2_lag',\n",
       "       'propnum', 'election_year', 'seniority', 'constituency_name',\n",
       "       'dublin_dummy', 'party', 'party_recoded', 'party_broad', 'name',\n",
       "       'author_id', 'id', 'in_reply_to_user_id', 'created_at',\n",
       "       'conversation_id', 'lang', 'public_metrics', 'date', 'mpname', 'land',\n",
       "       'candidate', 'first_pref_share', 'gender', 'district_magnitude_man',\n",
       "       'running_sum', 'elected_sum', 'term', 'housing_committee', 'info',\n",
       "       'housing_pos', 'county_recoded', 'n_houses', 'mean_price',\n",
       "       'median_price', 'mean_price_lag', 'median_price_lag',\n",
       "       'perc_change_mean', 'perc_change_median', 'change_ownership',\n",
       "       'change_became_owner', 'gov_opp', 'type', 'title', 'respondent',\n",
       "       'birth', 'label', 'score'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# add all variables from data_analysis to the predictions.\n",
    "# Selecting only the prediction columns keeps this cell safe to re-run (before the label\n",
    "# column is renamed in the next cell) instead of duplicating every column of data_analysis.\n",
    "predictions = dat_classified[[\"label\", \"score\"]]\n",
    "# exactly one prediction per text is required, otherwise rows would be paired up wrongly\n",
    "assert len(predictions) == len(data_analysis), \"expected exactly one prediction per document\"\n",
    "# the pipeline output carries a fresh RangeIndex; give it data_analysis's index so that\n",
    "# concat(axis=1) pairs rows by position even if data_analysis was read with a non-default index\n",
    "predictions = predictions.set_axis(data_analysis.index)\n",
    "dat_classified = pd.concat([data_analysis, predictions], axis = 1)\n",
    "\n",
    "# get names of variables in dat_classified\n",
    "dat_classified.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>col_0</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>housing_bert</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>Housing</th>\n",
       "      <td>53231</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Non-housing</th>\n",
       "      <td>855664</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "col_0          count\n",
       "housing_bert        \n",
       "Housing        53231\n",
       "Non-housing   855664"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# store the predicted class under the name housing_bert\n",
    "dat_classified = dat_classified.rename(columns={\"label\": \"housing_bert\"})\n",
    "\n",
    "# frequency of each predicted class\n",
    "pd.crosstab(index=dat_classified[\"housing_bert\"], columns=\"count\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save the data with the predictions attached as a parquet file\n",
    "output_path = \"data_dontshare/data_analysis_classified_housing.parquet\"\n",
    "dat_classified.to_parquet(output_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python-3.11.5",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
